"""
Created on 2021/12/15
note  : Parameter Configuration
author: Yuze Xuan, Xiaohu Hao, Xuan Wang, Sida Wang
"""

import os
import warnings

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings("ignore")

# * 参数设置
IMAGE_SIZE = (256, 256)  # width, height
CHANNELS = 3  # 图像通道数，一般为RGB
IMAGE_EXT = ['.png', '.jpg', '.jpeg']  # 图像格式（扩展名）信息列表
DATASET_BASE_PATH = 'dataset'  # 数据集路径
TRAIN_RATIO = 0.9  # 训练集占比，其余为验证集
SAVE_BASE_PATH = 'results'  # 结果存储根目录
TEST_INFO_NAME = 'test_info.csv'
SAVE_PRED_RESULT_BY_CAT = False  # 按类别保存模型分类结果

# 图像增强参数
IMAGE_ENHANCER_KWARGS = dict(
    crop_probability=0.5,  # 裁剪概率
    crop_fix_range=6,  # 裁剪量
    flip_probability=0.5,  # 翻转概率
    horizontal_flip=0.4,  # 水平翻转
    vertical_flip=0.4,  # 垂直翻转
    diagonal_flip=0.2,  # 对角翻转
)
# 训练超参数
TRAIN_KWARGS = dict(
    lr=0.1,  # 学习率
    batch_size=16,  # 每批次训练样本的个数
    epochs=80,  # 训练轮数
    patienceEpoch=10  # 可接受的精度连续下降轮数，用于EarlyStopping，防止过拟合
)

# * 初始路径配置
if not os.path.exists(SAVE_BASE_PATH):
    os.makedirs(SAVE_BASE_PATH)
if not os.path.exists(os.path.join(SAVE_BASE_PATH, 'checkpoints')):
    os.makedirs(os.path.join(SAVE_BASE_PATH, 'checkpoints'))
if not os.path.exists(os.path.join(SAVE_BASE_PATH, 'figures')):
    os.makedirs(os.path.join(SAVE_BASE_PATH, 'figures'))
